import os
from abc import ABCMeta, abstractmethod, ABC
import re
from xml.etree import ElementTree
import copy

import numpy as np

from libsvm import svmutil

import vmaf
from vmaf.config import VmafConfig
from vmaf import ExternalProgramCaller, convert_pixel_format_ffmpeg2vmafexec
from vmaf.core.executor import Executor
from vmaf.core.niqe_train_test_model import NiqeTrainTestModel
from vmaf.core.result import Result
from vmaf.core.feature_assembler import FeatureAssembler
from vmaf.core.train_test_model import TrainTestModel, LibsvmNusvrTrainTestModel, \
    BootstrapLibsvmNusvrTrainTestModel
from vmaf.core.feature_extractor import SsimFeatureExtractor, \
    MsSsimFeatureExtractor, \
    VmafFeatureExtractor, PsnrFeatureExtractor, VmafIntegerFeatureExtractor, \
    FeatureExtractor
from vmaf.core.vmafexec_feature_extractor import CIEDE2000FeatureExtractor
from vmaf.tools.decorator import override

__copyright__ = "Copyright 2016-2020, Netflix, Inc."
__license__ = "BSD+Patent"

from vmaf.tools.misc import piecewise_linear_mapping


class QualityRunner(Executor):
    """
    Executor that runs quality assessment on a list of assets and returns a
    list of corresponding results. Each QualityRunner must declare a unique
    TYPE and VERSION combination, so that the Result it generates can be
    identified and stored by the ResultStore class.

    A derived class can be written in one of two ways:

    a) Call a command-line executable directly, much like a FeatureExtractor
    does. You must:
        1) Override TYPE and VERSION.
        2) Override _generate_result(self, asset), which calls the
        command-line executable and writes the quality scores to a log file.
        3) Override _get_quality_scores(self, asset), which reads the quality
        scores from the log file and returns them as a dictionary.
        4) If necessary, override _remove_log(self, asset) when
        Executor._remove_log(self, asset) does not fit your purpose (for
        instance, the executable may write its output in a different layout,
        such as multiple files).

    b) Override Executor._run_on_asset(self, asset) to bypass the regular
    routine. Inside it, construct a FeatureAssembler (which calls one or more
    FeatureExtractors and assembles a list of features), then use a
    TrainTestModel (pre-trained elsewhere) to predict the final quality
    score. You must:
        1) Override TYPE and VERSION.
        2) Override _run_on_asset(self, asset) so that it runs a
        FeatureAssembler, collects a feature vector, runs
        TrainTestModel.predict() on it and returns a Result object (in this
        case both Executor._run_on_asset(self, asset) and
        QualityRunner._read_result(self, asset) are bypassed).
        3) Override _remove_log(self, asset) by redirecting it to the
        FeatureAssembler.
        4) Override _remove_result(self, asset) by redirecting it to the
        FeatureAssembler.
    For an example, follow VmafQualityRunner (QualityRunnerFromFeatureExtractor
    in this file takes the same approach).
    """

    # NOTE(review): the ``__metaclass__`` attribute is the Python 2 spelling
    # and has no effect on Python 3; kept as-is to preserve behavior.
    __metaclass__ = ABCMeta

    @abstractmethod
    def _get_quality_scores(self, asset):
        """Return the quality scores for `asset` as a dictionary."""
        raise NotImplementedError

    def _read_result(self, asset):
        """Wrap the quality scores of `asset` into a Result object."""
        quality_scores = dict(self._get_quality_scores(asset))
        return Result(asset, self.executor_id, quality_scores)

    @classmethod
    def get_scores_key(cls):
        """Return the result key under which the score list is stored."""
        return "{}_scores".format(cls.TYPE)

    @classmethod
    def get_score_key(cls):
        """Return the result key of the single (aggregate) score."""
        return "{}_score".format(cls.TYPE)


class QualityRunnerFromFeatureExtractor(QualityRunner):
    """
    QualityRunner whose score is simply one of the features produced by a
    single FeatureExtractor class.

    Subclasses name the extractor class (_get_feature_extractor_class) and
    the feature whose scores become the quality scores
    (_get_feature_key_for_score).
    """

    # NOTE(review): Python 2 style metaclass declaration; no effect on
    # Python 3. Kept unchanged.
    __metaclass__ = ABCMeta

    @abstractmethod
    def _get_feature_extractor_class(self):
        """Return the FeatureExtractor class to run."""
        raise NotImplementedError

    @abstractmethod
    def _get_feature_key_for_score(self):
        """Return the feature name whose scores are reported as the score."""
        raise NotImplementedError

    def _get_quality_scores(self, asset):
        # Not used: _run_on_asset below builds the Result directly.
        raise NotImplementedError

    def _generate_result(self, asset):
        # Not used: _run_on_asset below builds the Result directly.
        raise NotImplementedError

    def _get_feature_assembler_instance(self, asset):
        """
        Build a FeatureAssembler for `asset` that requests every atom feature
        (plus derived atom features, when the extractor declares them) of
        the configured extractor class.
        """
        fextractor_class = self._get_feature_extractor_class()
        requested_features = fextractor_class.ATOM_FEATURES + \
            getattr(fextractor_class, 'DERIVED_ATOM_FEATURES', [])

        return FeatureAssembler(
            feature_dict={fextractor_class.TYPE: requested_features},
            feature_option_dict=None,
            assets=[asset],
            logger=self.logger,
            fifo_mode=self.fifo_mode,
            delete_workdir=self.delete_workdir,
            result_store=self.result_store,
            optional_dict=self.optional_dict,
            optional_dict2=self.optional_dict2,
            parallelize=False,  # parallelization already in a higher level
            save_workfiles=self.save_workfiles,
        )

    @override(Executor)
    def _run_on_asset(self, asset):
        """
        Run the feature assembler on `asset` and return a Result holding all
        feature results, with the selected feature's scores re-keyed under
        this runner's own scores key.
        """
        assembler = self._get_feature_assembler_instance(asset)
        assembler.run()
        feature_result = assembler.results[0]

        feature_scores_key = self._get_feature_extractor_class().get_scores_key(
            self._get_feature_key_for_score())

        # start from a copy of all feature results
        result_dict = dict(feature_result.result_dict)
        # add the score under the runner's key ...
        result_dict[self.get_scores_key()] = feature_result.result_dict[feature_scores_key]
        # ... and drop the now-redundant feature entry
        del result_dict[feature_scores_key]

        return Result(asset, self.executor_id, result_dict)

    @override(Executor)
    def _remove_result(self, asset):
        # Redirect result removal to the FeatureAssembler.
        self._get_feature_assembler_instance(asset).remove_results()


class PsnrQualityRunner(QualityRunnerFromFeatureExtractor, ABC):
    """
    Quality runner that reports the 'psnr' feature of PsnrFeatureExtractor
    as its quality score.
    """
    TYPE = 'PSNR'
    VERSION = '1.0'

    @override(QualityRunnerFromFeatureExtractor)
    def _get_feature_extractor_class(self):
        """Return the feature extractor class this runner delegates to."""
        return PsnrFeatureExtractor

    @override(QualityRunnerFromFeatureExtractor)
    def _get_feature_key_for_score(self):
        """Return the name of the feature whose scores become the score."""
        return 'psnr'


class VmafLegacyQualityRunner(QualityRunner):
    """
    Legacy VMAF runner: assembles the 'VMAF_feature' features, rescales four
    of them, feeds them frame by frame to a libsvm model loaded from
    SVM_MODEL_FILE, and applies a post-SVM correction to each score.
    """
    TYPE = 'VMAF_legacy'

    # VERSION = '1.1'
    VERSION = 'F' + VmafFeatureExtractor.VERSION + '-1.1'

    FEATURE_ASSEMBLER_DICT = {'VMAF_feature': 'all'}

    # (lower, upper) bounds used by _rescale for each model input
    FEATURE_RESCALE_DICT = {'VMAF_feature_vif_scores': (0.0, 1.0),
                            'VMAF_feature_adm_scores': (0.4, 1.0),
                            'VMAF_feature_ansnr_scores': (10.0, 50.0),
                            'VMAF_feature_motion_scores': (0.0, 20.0)}

    SVM_MODEL_FILE = VmafConfig.model_path("other_models", "model_V8a.model")

    # model_v8a.model is trained with a customized feature order:
    SVM_MODEL_ORDERED_SCORES_KEYS = ['VMAF_feature_vif_scores',
                                     'VMAF_feature_adm_scores',
                                     'VMAF_feature_ansnr_scores',
                                     'VMAF_feature_motion_scores']

    def _get_quality_scores(self, asset):
        # Not used: _run_on_asset below builds the Result directly.
        raise NotImplementedError

    def _generate_result(self, asset):
        # Not used: _run_on_asset below builds the Result directly.
        raise NotImplementedError

    def _get_vmaf_feature_assembler_instance(self, asset):
        """Build the FeatureAssembler that extracts the features of `asset`."""
        return FeatureAssembler(
            feature_dict=self.FEATURE_ASSEMBLER_DICT,
            feature_option_dict=None,
            assets=[asset],
            logger=self.logger,
            fifo_mode=self.fifo_mode,
            delete_workdir=self.delete_workdir,
            result_store=self.result_store,
            optional_dict=None,
            optional_dict2=None,
            parallelize=False,  # parallelization already in a higher level
            save_workfiles=self.save_workfiles,
        )

    @override(Executor)
    def _run_on_asset(self, asset):
        """
        Override Executor._run_on_asset(self, asset): run a FeatureAssembler,
        collect the feature vectors, predict with the SVM model, and return
        a Result object (both Executor._run_on_asset(self, asset) and
        QualityRunner._read_result(self, asset) get bypassed).
        """
        assembler = self._get_vmaf_feature_assembler_instance(asset)
        assembler.run()
        feature_result = assembler.results[0]

        # SVR predict
        svm_model = svmutil.svm_load_model(self.SVM_MODEL_FILE)

        # one rescaled score sequence per model input, in the model's order
        scaled_columns = [
            self._rescale(feature_result[key], self.FEATURE_RESCALE_DICT[key])
            for key in self.SVM_MODEL_ORDERED_SCORES_KEYS
        ]

        scores = []
        for vif, adm, ansnr, motion in zip(*scaled_columns):
            raw_score = svmutil.svm_predict(
                [0], [[vif, adm, ansnr, motion]], svm_model)[0][0]
            # NOTE(review): `motion` here is the rescaled value (clipped to
            # its bounds and normalized by _rescale), so the `motion > 12.0`
            # branch of _post_correction looks unreachable from this call --
            # confirm whether that is intended.
            scores.append(self._post_correction(motion, raw_score))

        # all feature results, plus the quality scores
        result_dict = dict(feature_result.result_dict)
        result_dict[self.get_scores_key()] = scores

        return Result(asset, self.executor_id, result_dict)

    def _post_correction(self, motion, score):
        """
        Post-SVM correction: boost the score when motion exceeds 12 (motion
        capped at 20), then clamp the score to [0, 100].
        """
        if motion > 12.0:
            capped_motion = 20 if motion > 20.0 else motion
            score *= ((capped_motion - 12) * 0.015 + 1)

        if score > 100.0:
            return 100.0
        if score < 0.0:
            return 0.0
        return score

    @classmethod
    def _rescale(cls, vals, lower_upper_bound):
        """
        Clip `vals` to (lower, upper) and map that range linearly to [0, 1].
        """
        lo, hi = lower_upper_bound
        clipped = np.clip(np.double(vals), lo, hi)
        return (clipped - lo) / (hi - lo)

    @override(Executor)
    def _remove_result(self, asset):
        # Override Executor._remove_result by redirecting it to the
        # FeatureAssembler.
        self._get_vmaf_feature_assembler_instance(asset).remove_results()


class VmafQualityRunnerModelMixin(object):
    """
    Mixin that resolves and loads the trained model used by a quality runner.

    The host class is expected to provide `optional_dict`, `logger`,
    `DEFAULT_MODEL_FILEPATH` and `get_train_test_model_class()`.
    """

    def _load_model(self, asset):
        """
        Load the model named by optional_dict['model_filepath'] if given (and
        not None), otherwise the one at DEFAULT_MODEL_FILEPATH.

        `asset` is not used here; it is part of the signature only.
        """
        model_filepath = None
        if self.optional_dict is not None:
            model_filepath = self.optional_dict.get('model_filepath')
        if model_filepath is None:
            model_filepath = self.DEFAULT_MODEL_FILEPATH
        train_test_model_class = self.get_train_test_model_class()
        return self._load_model_from_filepath(train_test_model_class, model_filepath, self.logger)

    @classmethod
    def _load_model_from_filepath(cls, train_test_model_class, model_filepath, logger):
        """
        Load a model via `train_test_model_class.from_file`, choosing the
        'pkl' or 'json' loader from the file extension.

        :raises AssertionError: if the extension is unsupported, or if the
            loader itself raises AssertionError (re-raised with the file path
            and class name, chained to the original error).
        """
        extension = os.path.splitext(model_filepath)[1]
        cls._assert_extension_format(['.pkl', '.json'], extension)
        try:
            # substring match on purpose: extensions such as '.pkl_2160' or
            # '.json_720' are accepted (see _assert_extension_format)
            if '.pkl' in extension:
                model = train_test_model_class.from_file(model_filepath, logger, format='pkl')
            elif '.json' in extension:
                model = train_test_model_class.from_file(model_filepath, logger, format='json', combined=True)
            else:
                # explicit raise rather than `assert False`, which would be
                # stripped under `python -O` and leave `model` unbound
                raise AssertionError('unsupported model file extension: {}'.format(extension))
        except AssertionError as e:
            raise AssertionError("File {filepath} may not be a valid model file for class {cls}: {e}".
                                 format(filepath=model_filepath, cls=train_test_model_class.__name__, e=e)) from e
        return model

    @classmethod
    def _assert_extension_format(cls, supported_formats, format):
        """
        Raise AssertionError unless `format` contains one of
        `supported_formats` as a substring.

        >>> supported_formats = ['.pkl', '.json']
        >>> VmafQualityRunner._assert_extension_format(supported_formats, '.pkl')
        >>> VmafQualityRunner._assert_extension_format(supported_formats, '.pkl_2160')
        >>> VmafQualityRunner._assert_extension_format(supported_formats, '.pkkl')
        Traceback (most recent call last):
        ...
        AssertionError: VmafQualityRunner supports .pkl or .json model file, but the file format is: .pkkl
        >>> VmafQualityRunner._assert_extension_format(supported_formats, '.json')
        >>> VmafQualityRunner._assert_extension_format(supported_formats, '.json_720')
        >>> VmafQualityRunner._assert_extension_format(supported_formats, '.jsonn')
        >>> VmafQualityRunner._assert_extension_format(supported_formats, '.jsson')
        Traceback (most recent call last):
        ...
        AssertionError: VmafQualityRunner supports .pkl or .json model file, but the file format is: .jsson
        """
        if any(supported_format in format for supported_format in supported_formats):
            return
        # explicit raise so the check also holds when asserts are disabled (-O)
        raise AssertionError(
            f'{cls.__name__} supports .pkl or .json model file, but the file format is: {format}')


class VmafQualityRunner(VmafQualityRunnerModelMixin, QualityRunner):
    """
    Quality runner that assembles the features required by a trained model,
    predicts per-unit scores with that model and optionally transforms and
    clips them (see _run_on_asset and predict_with_model).
    """
    TYPE = 'VMAF'

    # --- version history (kept for reference) ---

    # VERSION = '0.1' # using model nflxall_vmafv1.pkl, VmafFeatureExtractor VERSION 0.1
    # DEFAULT_MODEL_FILEPATH = VmafConfig.model_path("nflxall_vmafv1.pkl") # trained with resource/param/vmaf_v1.py on private/resource/dataset/NFLX_dataset.py (30 subjects)

    # VERSION = '0.2' # using model nflxall_vmafv2.pkl, VmafFeatureExtractor VERSION 0.2.1
    # DEFAULT_MODEL_FILEPATH = VmafConfig.model_path("nflxall_vmafv2.pkl") # trained with resource/param/vmaf_v2.py on private/resource/dataset/NFLX_dataset.py (30 subjects)

    # VERSION = '0.3' # using model nflxall_vmafv3.pkl, VmafFeatureExtractor VERSION 0.2.1
    # DEFAULT_MODEL_FILEPATH = VmafConfig.model_path("nflxall_vmafv3.pkl") # trained with resource/param/vmaf_v3.py on private/resource/dataset/NFLX_dataset.py (30 subjects)

    # VERSION = '0.3.1' # using model nflxall_vmafv3.pkl, VmafFeatureExtractor VERSION 0.2.1, NFLX_dataset with 26 subjects (last 4 outliers removed)
    # DEFAULT_MODEL_FILEPATH = VmafConfig.model_path("nflxall_vmafv3a.pkl") # trained with resource/param/vmaf_v3.py on private/resource/dataset/NFLX_dataset.py (26 subjects)

    # VERSION = '0.3.2'  # using model nflxall_vmafv4.pkl, VmafFeatureExtractor VERSION 0.2.2, NFLX_dataset with 26 subjects (last 4 outliers removed)
    # ALGO_VERSION = 0
    # DEFAULT_MODEL_FILEPATH = VmafConfig.model_path("nflxall_vmafv4.pkl")  # trained with resource/param/vmaf_v4.py on private/resource/dataset/NFLX_dataset.py (26 subjects)

    # # using model vmaf_float_v0.6.1.pkl. VmafFeatureExtractor VERSION 0.2.4b.
    # VERSION = 'F' + VmafFeatureExtractor.VERSION + '-0.6.1'
    # ALGO_VERSION = 2

    # --- current version ---

    # default model vmaf_v0.6.1.json
    VERSION = 'F' + VmafFeatureExtractor.VERSION + 'int' + '-0.6.1'
    ALGO_VERSION = 4

    # trained with resource/param/vmaf_v6.py on private/user/zli/resource/dataset/dataset/derived/vmafplusstudy_laptop_raw_generalandcornercase.py, MLER, y=x+17
    # modified from vmaf_float_v0.6.1.pkl to use integer features
    DEFAULT_MODEL_FILEPATH = vmaf.model_path("vmaf_v0.6.1.json")

    # fallback used when the loaded model carries no 'feature_dict' info
    DEFAULT_FEATURE_DICT = {'VMAF_feature': ['vif', 'adm', 'motion', 'ansnr']}  # for backward compatibility with older models only

    def _get_quality_scores(self, asset):
        # Not implemented: this class overrides _run_on_asset and builds the
        # Result there, so the log-reading hook of QualityRunner is bypassed.
        raise NotImplementedError

    def _generate_result(self, asset):
        # Not implemented: this class overrides _run_on_asset, so no
        # command-line log is generated by the runner itself.
        raise NotImplementedError

    def _get_vmaf_feature_assembler_instance(self, asset):
        """
        Build the FeatureAssembler for `asset`, configured from the extra
        info ('feature_dict', 'feature_opts_dicts', 'feature_names') appended
        to the loaded model.
        """
        model = self._load_model(asset)

        # features to extract; fall back to the default when the model does
        # not carry the info
        feature_dict = model.get_appended_info('feature_dict')
        if feature_dict is None:
            feature_dict = self.DEFAULT_FEATURE_DICT

        # feature_opts_dicts in the model is per-atom feature; it must be
        # mapped to per-aggregate feature (inconsistency raises an error)
        atom_feature_opts_dicts = model.get_appended_info('feature_opts_dicts')
        aggr_feature_opts_dict = None
        if atom_feature_opts_dicts is not None:
            atom_feature_names = model.get_appended_info('feature_names')
            assert atom_feature_names is not None
            assert feature_dict is not None
            aggr_feature_opts_dict = \
                self._get_aggr_feature_opts_dict_from_atom_feature_opts_dicts(
                    feature_dict, atom_feature_names, atom_feature_opts_dicts)

        return FeatureAssembler(
            feature_dict=feature_dict,
            feature_option_dict=aggr_feature_opts_dict,
            assets=[asset],
            logger=self.logger,
            fifo_mode=self.fifo_mode,
            delete_workdir=self.delete_workdir,
            result_store=self.result_store,
            optional_dict=self.optional_dict,
            optional_dict2=self.optional_dict2,
            parallelize=False,  # parallelization already in a higher level
            save_workfiles=self.save_workfiles,
        )

    @staticmethod
    def _get_aggr_feature_opts_dict_from_atom_feature_opts_dicts(feature_dict,
                                                                 atom_feature_names,
                                                                 atom_feature_opts_dicts):

        """
        Merge per-atom-feature option dicts into one option dict per aggregate
        feature. Two atom features of the same aggregate feature giving
        different values for the same option is an error.

        >>> atom_feature_names = ['VMAF_integer_feature_adm2_score', 'VMAF_integer_feature_motion2_score', 'VMAF_integer_feature_vif_scale0_score', 'VMAF_integer_feature_vif_scale1_score', 'VMAF_integer_feature_vif_scale2_score', 'VMAF_integer_feature_vif_scale3_score']
        >>> atom_feature_opts_dicts = [{'adm_enhn_gain_limit': 1.0}, {}, {'vif_enhn_gain_limit': 1.0}, {'vif_enhn_gain_limit': 1.0}, {'vif_enhn_gain_limit': 1.0}, {'vif_enhn_gain_limit': 1.0}]
        >>> feature_dict = {'VMAF_integer_feature': ['vif_scale0', 'vif_scale1', 'vif_scale2', 'vif_scale3', 'adm2', 'motion2']}
        >>> VmafQualityRunner._get_aggr_feature_opts_dict_from_atom_feature_opts_dicts(feature_dict, atom_feature_names, atom_feature_opts_dicts)
        {'VMAF_integer_feature': {'vif_enhn_gain_limit': 1.0, 'adm_enhn_gain_limit': 1.0}}
        >>> atom_feature_opts_dicts2 = [{'adm_enhn_gain_limit': 1.1}, {}, {'vif_enhn_gain_limit': 1.0}, {'vif_enhn_gain_limit': 1.0}, {'vif_enhn_gain_limit': 1.0}, {'vif_enhn_gain_limit': 1.0}]
        >>> VmafQualityRunner._get_aggr_feature_opts_dict_from_atom_feature_opts_dicts(feature_dict, atom_feature_names, atom_feature_opts_dicts2)
        {'VMAF_integer_feature': {'vif_enhn_gain_limit': 1.0, 'adm_enhn_gain_limit': 1.1}}
        >>> atom_feature_opts_dicts3 = [{'adm_enhn_gain_limit': 1.1}, {}, {'vif_enhn_gain_limit': 1.0}, {'vif_enhn_gain_limit': 1.0}, {'vif_enhn_gain_limit': 1.4}, {'vif_enhn_gain_limit': 1.0}]
        >>> VmafQualityRunner._get_aggr_feature_opts_dict_from_atom_feature_opts_dicts(feature_dict, atom_feature_names, atom_feature_opts_dicts3)
        Traceback (most recent call last):
        ...
        AssertionError: feature_opts_dicts are inconsistent for atom features belong to the same aggregate features: 1.0 vs. 1.4
        """

        assert len(atom_feature_opts_dicts) == len(atom_feature_names)
        opts_by_atom_name = dict(zip(atom_feature_names, atom_feature_opts_dicts))

        aggr_feature_opts_dict = {}
        for aggr_feature in feature_dict:
            fextractor_class = FeatureExtractor.find_subclass(aggr_feature)
            for atom_feature in feature_dict[aggr_feature]:
                atom_score_key = fextractor_class.get_score_key(atom_feature)
                if atom_score_key not in opts_by_atom_name:
                    continue

                # the aggregate entry is created on first matching atom feature
                merged_opts = aggr_feature_opts_dict.setdefault(aggr_feature, dict())
                atom_opts = opts_by_atom_name[atom_score_key]
                for opt in atom_opts:
                    if opt not in merged_opts:
                        merged_opts[opt] = atom_opts[opt]
                        continue
                    # option already set by a sibling atom feature: must agree
                    assert merged_opts[opt] == atom_opts[opt], \
                        'feature_opts_dicts are inconsistent for atom features belong to the same aggregate features: {} vs. {}'.format(
                            merged_opts[opt],
                            atom_opts[opt]
                        )
        return aggr_feature_opts_dict

    @override(Executor)
    def _run_on_asset(self, asset):
        """
        Override Executor._run_on_asset(self, asset): run a FeatureAssembler,
        collect the feature vectors, predict with the loaded model and return
        a Result object (both Executor._run_on_asset(self, asset) and
        QualityRunner._read_result(self, asset) get bypassed).
        """
        assembler = self._get_vmaf_feature_assembler_instance(asset)
        assembler.run()
        feature_result = assembler.results[0]

        model = self._load_model(asset)
        xs = model.get_per_unit_xs_from_a_result(feature_result)

        # forward only the prediction flags that the caller actually set
        more = dict()
        for flag_name in ('disable_clip_score', 'enable_transform_score'):
            if self.optional_dict is None or flag_name not in self.optional_dict:
                continue
            flag_value = self.optional_dict[flag_name]
            assert isinstance(flag_value, bool)
            if flag_value is not None:
                more[flag_name] = flag_value

        pred_result = self.predict_with_model(model, xs, **more)
        return Result(asset, self.executor_id,
                      self._populate_result_dict(feature_result, pred_result))

    def _populate_result_dict(self, feature_result, pred_result):
        """
        Return a new dict holding every feature result plus the predicted
        quality scores stored under this runner's scores key.
        """
        combined = dict(feature_result.result_dict)  # feature results
        combined[self.get_scores_key()] = pred_result['ys_pred']  # quality score
        return combined

    @classmethod
    def predict_with_model(cls, model, xs, **kwargs):
        """
        Predict scores for `xs` with `model`, then apply the optional score
        transform and (unless disable_clip_score is True) score clipping.

        :return: dict with the final scores under key 'ys_pred'.
        """
        ys_pred = model.predict(xs)['ys_label_pred']

        if cls._do_transform_score(model, kwargs):
            ys_pred = cls.transform_score(model, ys_pred)

        # clipping is on by default; only an explicit True switches it off
        if kwargs.get('disable_clip_score') is not True:
            ys_pred = cls.clip_score(model, ys_pred)

        return {'ys_pred': ys_pred}

    @staticmethod
    def _do_transform_score(model, kwargs):
        """
        Decide whether the score transform applies, combining the model's
        'score_transform' -> 'enabled' flag with kwargs['enable_transform_score'].
        A missing flag is ignored; when both are missing the answer is False.
        """
        model_flag = None
        transform_dict = model.get_appended_info('score_transform')
        if transform_dict is not None and 'enabled' in transform_dict:
            model_flag = transform_dict['enabled']
            assert isinstance(model_flag, bool)

        kwargs_flag = None
        if 'enable_transform_score' in kwargs:
            kwargs_flag = kwargs['enable_transform_score']
            assert isinstance(kwargs_flag, bool)

        if model_flag is None:
            return False if kwargs_flag is None else kwargs_flag
        if kwargs_flag is None:
            return model_flag
        # as long as one is True, transform is enabled
        # this is consistent behavior with VmafexecQualityRunner (libvmaf)
        return model_flag or kwargs_flag

    @staticmethod
    def set_transform_score(model, score_transform):
        """Attach `score_transform` to `model` under the key 'score_transform'
        (the key read back by transform_score and _do_transform_score)."""
        model.append_info('score_transform', score_transform)

    @staticmethod
    def set_clip_score(model, score_clip):
        """Attach `score_clip` to `model` under the key 'score_clip'
        (the key read back by clip_score)."""
        model.append_info('score_clip', score_clip)

    @staticmethod
    def transform_score(model, y_in):
        """
        Transform the final quality score through the following optional
        steps (in this order), as configured by the model's
        'score_transform' info:
        1) polynomial mapping, e.g. {'p0': 1, 'p1': 1, 'p2': 0.5} means
        transform through 1 + x + 0.5 * x^2. For now, only polynomials up to
        2nd order are supported.
        2) piecewise-linear mapping, where the change points are defined in
        'knots', in the form of [[x0, y0], [x1, y1], ...].
        3) rectification, supporting 'out_lte_in' (output is less than or
        equal to input) and 'out_gte_in' (output is greater than or equal to
        input); each is active only when set to the string 'true'.

        When the model has no 'score_transform' info, `y_in` is returned
        as is.
        """

        transform_dict = model.get_appended_info('score_transform')
        if transform_dict is None:
            return y_in

        # work on a copy so that y_in stays available for rectification
        y_out = np.copy(y_in)

        # polynomial mapping
        if any(coef in transform_dict for coef in ('p0', 'p1', 'p2')):
            y_poly_in = y_out
            y_out = np.zeros(y_poly_in.shape)
            if 'p0' in transform_dict:
                y_out += transform_dict['p0']
            if 'p1' in transform_dict:
                y_out += transform_dict['p1'] * y_poly_in
            if 'p2' in transform_dict:
                y_out += transform_dict['p2'] * y_poly_in * y_poly_in

        # piecewise-linear mapping
        if 'knots' in transform_dict:
            y_out = piecewise_linear_mapping(np.copy(y_out), transform_dict['knots'])

        # rectification
        if transform_dict.get('out_lte_in') == 'true':
            # output must be less than or equal to input
            y_out = np.minimum(y_out, y_in)
        if transform_dict.get('out_gte_in') == 'true':
            # output must be greater than or equal to input
            y_out = np.maximum(y_out, y_in)

        return y_out

    @staticmethod
    def clip_score(model, ys_pred):
        """
        Post-processing: clip the final quality score to the (lower, upper)
        pair stored in the model's 'score_clip' info, e.g. [0, 100]. Scores
        are returned unchanged when the model has no such info.
        """
        score_clip = model.get_appended_info('score_clip')
        if score_clip is None:
            return ys_pred

        lower, upper = score_clip
        return np.clip(ys_pred, lower, upper)

    def get_train_test_model_class(self):
        """Return the TrainTestModel class used by _load_model to read the
        model file."""
        return LibsvmNusvrTrainTestModel

    @override(Executor)
    def _remove_result(self, asset):
        """
        Override Executor._remove_result(self, asset) by redirecting it to
        the FeatureAssembler.
        """
        self._get_vmaf_feature_assembler_instance(asset).remove_results()


class VmafnegQualityRunner(VmafQualityRunner):
    """
    VmafQualityRunner variant whose default model is 'vmaf_v0.6.1neg.json';
    everything else is inherited.
    """
    TYPE = 'VMAFNEG'
    DEFAULT_MODEL_FILEPATH = vmaf.model_path('vmaf_v0.6.1neg.json')


class EnsembleVmafQualityRunner(VmafQualityRunner):
    """
    VMAF runner that predicts with a list of models and reports the mean of
    their per-frame scores (see _run_on_asset).
    """
    TYPE = 'EnsembleVMAF'

    VERSION = '{}-Ensemble'.format(VmafQualityRunner.VERSION)

    # one model file path per ensemble member
    DEFAULT_MODEL_FILEPATH = [VmafConfig.model_path("vmaf_float_v0.6.1.pkl"), VmafConfig.model_path("vmaf_float_v0.6.1.pkl")]

    # this now needs to become a list (one feature dict per ensemble member)
    DEFAULT_FEATURE_DICT = [{'VMAF_feature': ['vif', 'adm', 'motion', 'ansnr']}, {'VMAF_feature': ['vif', 'adm', 'motion', 'ansnr']}]

    def _populate_result_dict(self, feature_result, pred_result, result_dict):
        """
        Update `result_dict` in place with the feature results and the
        predicted scores, and return it.

        Note the extra `result_dict` parameter compared with
        VmafQualityRunner._populate_result_dict: the same dict is reused
        across ensemble members by _run_on_asset.
        """
        result_dict.update(feature_result.result_dict)  # add feature result
        result_dict[self.get_scores_key()] = pred_result['ys_pred']  # add quality score
        return result_dict

    def _get_ensemblevmaf_feature_assembler_instance(self, asset):
        """
        Return a list with one FeatureAssembler per ensemble model, each
        configured from that model's 'feature_dict' info (or the matching
        DEFAULT_FEATURE_DICT entry when the model has none).
        """

        def _build_assembler(model_ind, model_now):
            feature_dict = model_now.get_appended_info('feature_dict')
            if feature_dict is None:
                feature_dict = self.DEFAULT_FEATURE_DICT[model_ind]

            # NOTE(review): unlike the other assemblers in this file,
            # save_workfiles is not forwarded here -- confirm intentional.
            return FeatureAssembler(
                feature_dict=feature_dict,
                feature_option_dict=None,
                assets=[asset],
                logger=self.logger,
                fifo_mode=self.fifo_mode,
                delete_workdir=self.delete_workdir,
                result_store=self.result_store,
                optional_dict=None,  # WARNING: feature param not passed
                optional_dict2=None,
                parallelize=False,  # parallelization already in a higher level
            )

        # load TrainTestModel only to retrieve its 'feature_dict' extra info
        ensem_models = self._load_model(asset)
        return [_build_assembler(model_ind, model_now)
                for model_ind, model_now in enumerate(ensem_models)]

    def _get_Nframes(self, pred_result):
        """Return the number of predicted scores in pred_result['ys_pred']."""

        return len(pred_result['ys_pred'])

    @override(Executor)
    def _run_on_asset(self, asset):
        """
        Override Executor._run_on_asset(self, asset): for every ensemble
        model, run its FeatureAssembler, collect the feature vectors and
        predict; then average the per-model scores and return a Result
        object (both Executor._run_on_asset(self, asset) and
        QualityRunner._read_result(self, asset) get bypassed).

        The Result holds the feature results, each model's scores under the
        keys from ensemblevmaf_get_scores_key, and the simple average of all
        models under this runner's scores key.

        :raises AssertionError: if there is no model in the ensemble.
        """
        ensemblevmaf_fassemblers = self._get_ensemblevmaf_feature_assembler_instance(asset)

        # each model is associated with a Feature Assembler
        Nmodels = len(ensemblevmaf_fassemblers)
        assert Nmodels > 0

        # Load the ensemble once up front. Loading inside the loop below
        # would re-read every model file for every ensemble member.
        models = self._load_model(asset)
        assert len(models) == Nmodels

        # prediction flags do not depend on the model, so resolve them once
        optional_dict = self.optional_dict if self.optional_dict is not None else {}
        disable_clip_score = optional_dict.get('disable_clip_score', False)
        enable_transform_score = optional_dict.get('enable_transform_score', False)

        pred_result_ensem_models = []
        result_dict = {}
        for evmaf_fassembler, model in zip(ensemblevmaf_fassemblers, models):
            evmaf_fassembler.run()
            feature_result = evmaf_fassembler.results[0]
            xs = model.get_per_unit_xs_from_a_result(feature_result)

            pred_result = self.predict_with_model(model, xs,
                                                  disable_clip_score=disable_clip_score,
                                                  enable_transform_score=enable_transform_score)
            result_dict = self._populate_result_dict(feature_result, pred_result, result_dict)
            pred_result_ensem_models.append(pred_result)

        # frame count is taken from the last model's prediction
        Nframes = self._get_Nframes(pred_result_ensem_models[-1])

        all_model_scores = np.zeros((Nmodels, Nframes))
        all_model_score_names = self.ensemblevmaf_get_scores_key(Nmodels)
        for model_ind, model_pred_result in enumerate(pred_result_ensem_models):
            result_dict[all_model_score_names[model_ind]] = model_pred_result['ys_pred']  # add quality score
            all_model_scores[model_ind, :] = model_pred_result['ys_pred']

        # perform prediction averaging (simple average for now)
        pred_result_ensem_models_aggregate = np.mean(all_model_scores, axis=0)

        # write results
        result_dict[self.get_scores_key()] = pred_result_ensem_models_aggregate

        return Result(asset, self.executor_id, result_dict)

    def ensemblevmaf_get_scores_key(self, Nmodels):
        """
        Return the per-model scores keys, '<TYPE>_model_<index>_scores', for
        model indices 0 .. Nmodels - 1.
        """
        return [self.TYPE + '_model_' + str(model_ind) + '_scores'
                for model_ind in range(Nmodels)]

    def _load_model(self, asset):
        """
        Load every model of the ensemble and return them as a list.

        The file paths come from optional_dict['model_filepath'] when given
        (and not None), otherwise from DEFAULT_MODEL_FILEPATH; either way an
        iterable of paths is expected.
        """
        model_filepaths = None
        if self.optional_dict is not None:
            model_filepaths = self.optional_dict.get('model_filepath')
        if model_filepaths is None:
            model_filepaths = self.DEFAULT_MODEL_FILEPATH

        return [
            self._load_model_from_filepath(TrainTestModel, model_filepath_part, self.logger)
            for model_filepath_part in model_filepaths
        ]

    @override(Executor)
    def _remove_result(self, asset):
        """
        Override Executor._remove_result(self, asset) by redirecting it to
        the FeatureAssembler of every ensemble model.
        """
        for fassembler in self._get_ensemblevmaf_feature_assembler_instance(asset):
            fassembler.remove_results()


class VmafPhoneQualityRunner(VmafQualityRunner):
    """
    VmafQualityRunner variant whose _do_transform_score() always answers
    True. Because the answer is fixed, passing 'enable_transform_score' in
    optional_dict is rejected by _assert_args().
    """

    TYPE = 'VMAF_Phone'

    VERSION = '{}-phone'.format(VmafQualityRunner.VERSION)

    def _assert_args(self):
        # run the parent's checks first, then add the phone-specific one
        super()._assert_args()

        if self.optional_dict is None:
            return

        assert 'enable_transform_score' not in self.optional_dict, \
            'Cannot specify enable_transform_score option in {cls}.'.format(cls=self.__class__.__name__)

    @staticmethod
    @override(VmafQualityRunner)
    def _do_transform_score(model, kwargs):
        # unconditional: neither the model nor the kwargs are consulted
        return True


class FeatureDiscoveryMixin(object):
    """
    Helpers that look a feature up among the attributes of one frame element
    (frame.attrib) and, on a hit, record its value and its nickname in the
    caller-supplied per-feature lists.
    """

    @staticmethod
    def _discover_feature_exact(frame, i_feature, feature_, feature_origin,
                                feature_scores, feature_nicknames):
        """
        Look for an attribute named exactly feature_.

        On a hit, the attribute value is converted to float and appended to
        feature_scores[i_feature]; feature_nicknames[i_feature] is set to
        feature_origin if still None, otherwise it must already equal
        feature_origin. Returns True on a hit, False otherwise.
        """
        if feature_ not in frame.attrib:
            return False

        feature_scores[i_feature].append(float(frame.attrib[feature_]))

        if feature_nicknames[i_feature] is None:
            feature_nicknames[i_feature] = feature_origin
        else:
            # every frame must resolve the feature to the same nickname
            assert feature_nicknames[i_feature] == feature_origin
        return True

    @staticmethod
    def _discover_feature_wildcard(frame, i_feature, feature_prefix, feature_origin,
                                   feature_scores, feature_nicknames):
        """
        Look for the first attribute whose name starts with feature_prefix.

        On a hit, the attribute value is converted to float and appended to
        feature_scores[i_feature]. The nickname is feature_origin + '_' + the
        part of the attribute name following the prefix; it is stored in
        feature_nicknames[i_feature] if that slot is still None, otherwise it
        must equal what is stored. Returns True on a hit, False otherwise.
        """
        matched_name = next(
            (name for name in frame.attrib if name.startswith(feature_prefix)),
            None)
        if matched_name is None:
            return False

        feature_scores[i_feature].append(float(frame.attrib[matched_name]))

        nickname = feature_origin + '_' + matched_name[len(feature_prefix):]
        if feature_nicknames[i_feature] is None:
            feature_nicknames[i_feature] = nickname
        else:
            # every frame must resolve the feature to the same nickname
            assert feature_nicknames[i_feature] == nickname
        return True


class VmafossExecQualityRunner(QualityRunner, FeatureDiscoveryMixin):
    """
    Quality runner that produces its scores by calling
    ExternalProgramCaller.call_vmafossexec() and then parsing the per-frame
    entries of the XML log file that call writes.
    """

    TYPE = 'VMAFOSSEXEC'

    # VERSION = '0.3'
    # DEFAULT_MODEL_FILEPATH_DOTMODEL = VmafConfig.model_path("nflxall_vmafv3.pkl.model")

    # VERSION = '0.3.1'
    # DEFAULT_MODEL_FILEPATH_DOTMODEL = VmafConfig.model_path("nflxall_vmafv3a.pkl.model")

    # VERSION = '0.3.2'
    # ALGO_VERSION = 0
    # # DEFAULT_MODEL_FILEPATH_DOTMODEL = VmafConfig.model_path("nflxall_vmafv4.pkl.model")
    # DEFAULT_MODEL_FILEPATH = VmafConfig.model_path("nflxall_vmafv4.pkl")

    VERSION = 'F' + VmafFeatureExtractor.VERSION + '-0.6.1'
    ALGO_VERSION = 2

    # trained with resource/param/vmaf_v6.py on private/user/zli/resource/dataset/dataset/derived/vmafplusstudy_laptop_raw_generalandcornercase.py, MLER, y=x+17
    DEFAULT_MODEL_FILEPATH = vmaf.model_path("vmaf_v0.6.1.json")

    # features looked up in every frame of the log; see _get_quality_scores()
    FEATURES = ['adm2', 'adm_scale0', 'adm_scale1', 'adm_scale2', 'adm_scale3',
                'motion', 'vif_scale0', 'vif_scale1', 'vif_scale2',
                'vif_scale3', 'vif', 'psnr', 'ssim', 'ms_ssim', 'motion2',
                'vmaf_bagging', 'vmaf_stddev', 'vmaf_ci_p95_lo', 'vmaf_ci_p95_hi']

    @classmethod
    def get_feature_scores_key(cls, atom_feature):
        """Return the result key '<TYPE>_<atom_feature>_scores'."""
        return "{type}_{atom_feature}_scores".format(
            type=cls.TYPE, atom_feature=atom_feature)

    def _generate_result(self, asset):
        """
        Call the command-line executable on the asset so that it writes the
        quality scores into the asset's log file.

        All settings are read from optional_dict when present there and
        otherwise default to: model_filepath -> DEFAULT_MODEL_FILEPATH (also
        used when the entry is None), disable_clip_score / enable_transform_score /
        phone_model / disable_avx / ci -> False, thread -> 0, subsample -> 1,
        psnr / ssim / ms_ssim -> True (these three must be bool).
        """
        options = self.optional_dict if self.optional_dict is not None else {}

        log_file_path = self._get_log_file_path(asset)

        model = options.get('model_filepath')
        if model is None:
            model = self.DEFAULT_MODEL_FILEPATH

        disable_clip_score = options.get('disable_clip_score', False)
        enable_transform_score = options.get('enable_transform_score', False)
        phone_model = options.get('phone_model', False)
        disable_avx = options.get('disable_avx', False)
        n_thread = options.get('thread', 0)
        n_subsample = options.get('subsample', 1)

        psnr = options.get('psnr', True)
        assert isinstance(psnr, bool)

        ssim = options.get('ssim', True)
        assert isinstance(ssim, bool)

        ms_ssim = options.get('ms_ssim', True)
        assert isinstance(ms_ssim, bool)

        ci = options.get('ci', False)

        w, h = asset.quality_width_height
        fmt = self._get_workfile_yuv_type(asset)
        ref_path = asset.ref_procfile_path
        dis_path = asset.dis_procfile_path
        exe = self._get_exec()

        ExternalProgramCaller.call_vmafossexec(
            fmt, w, h, ref_path, dis_path, model, log_file_path,
            disable_clip_score, enable_transform_score,
            phone_model, disable_avx, n_thread, n_subsample,
            psnr, ssim, ms_ssim, ci, exe, self.logger)

    def _get_exec(self):
        return None  # signaling default

    def _get_quality_scores(self, asset):
        """
        Read the quality scores back from the asset's log file and return
        them as a dictionary.

        The 'vmaf' attribute of every frame feeds the main score list. Each
        feature of FEATURES -- plus any attribute of the first frame whose
        name matches vmaf_<digits> (per bootstrap model score) -- is looked
        up per frame and, if found in any frame, reported under
        get_feature_scores_key(<nickname>).
        """
        log_file_path = self._get_log_file_path(asset)
        frames = ElementTree.parse(log_file_path).getroot().findall('frames/frame')

        # check for per bootstrap model score
        augmented_features = copy.copy(self.FEATURES)
        augmented_features.extend(
            attr_name for attr_name in frames[0].attrib.keys()
            if re.match(r"vmaf_[0-9]+", attr_name))

        scores = []
        feature_scores = [[] for _ in augmented_features]
        feature_nicknames = [None] * len(augmented_features)

        for frame in frames:
            scores.append(float(frame.attrib['vmaf']))
            for i_feature, feature in enumerate(augmented_features):
                # tried in this order, stopping at the first hit:
                #   exact integer_xxx, exact xxx,
                #   wildcard integer_xxx_*, wildcard xxx_*
                lookups = (
                    (self._discover_feature_exact, 'integer_' + feature),
                    (self._discover_feature_exact, feature),
                    (self._discover_feature_wildcard, 'integer_' + feature + '_'),
                    (self._discover_feature_wildcard, feature + '_'),
                )
                for discover, name_or_prefix in lookups:
                    if discover(frame, i_feature, name_or_prefix, feature,
                                feature_scores, feature_nicknames):
                        break

        assert len(scores) != 0
        quality_result = {self.get_scores_key(): scores}
        for nickname, values in zip(feature_nicknames, feature_scores):
            if len(values) != 0:
                assert nickname is not None
                quality_result[self.get_feature_scores_key(nickname)] = values
        return quality_result


class SsimQualityRunner(QualityRunnerFromFeatureExtractor, ABC):
    """
    Quality runner of type 'SSIM', built on SsimFeatureExtractor and keyed on
    that extractor's 'ssim' feature.
    """
    TYPE = 'SSIM'
    VERSION = '1.0'

    @override(QualityRunnerFromFeatureExtractor)
    def _get_feature_extractor_class(self):
        # the feature extractor class this runner relies on
        return SsimFeatureExtractor

    @override(QualityRunnerFromFeatureExtractor)
    def _get_feature_key_for_score(self):
        # feature key; going by the method name, presumably the feature the
        # base class reports as the quality score -- confirm against
        # QualityRunnerFromFeatureExtractor
        return 'ssim'


class MsSsimQualityRunner(QualityRunnerFromFeatureExtractor, ABC):
    """
    Quality runner of type 'MS_SSIM', built on MsSsimFeatureExtractor and
    keyed on that extractor's 'ms_ssim' feature.
    """
    TYPE = 'MS_SSIM'
    VERSION = '1.0'

    @override(QualityRunnerFromFeatureExtractor)
    def _get_feature_extractor_class(self):
        # the feature extractor class this runner relies on
        return MsSsimFeatureExtractor

    @override(QualityRunnerFromFeatureExtractor)
    def _get_feature_key_for_score(self):
        # feature key; going by the method name, presumably the feature the
        # base class reports as the quality score -- confirm against
        # QualityRunnerFromFeatureExtractor
        return 'ms_ssim'


class Ciede2000QualityRunner(QualityRunnerFromFeatureExtractor, ABC):
    """
    Quality runner of type 'CIEDE2000', built on CIEDE2000FeatureExtractor
    and keyed on that extractor's 'ciede2000' feature.
    """
    TYPE = 'CIEDE2000'
    VERSION = '1.0'

    @override(QualityRunnerFromFeatureExtractor)
    def _get_feature_extractor_class(self):
        # the feature extractor class this runner relies on
        return CIEDE2000FeatureExtractor

    @override(QualityRunnerFromFeatureExtractor)
    def _get_feature_key_for_score(self):
        # feature key; going by the method name, presumably the feature the
        # base class reports as the quality score -- confirm against
        # QualityRunnerFromFeatureExtractor
        return 'ciede2000'


class VmafSingleFeatureQualityRunner(QualityRunner):
    """
    Base class of the quality runners whose score is one single feature
    computed by VmafIntegerFeatureExtractor. A subclass names that feature
    through FEATURE_NAME (and supplies its own TYPE).
    """

    # NOTE(review): this is the Python 2 way of declaring a metaclass; under
    # Python 3 it is just a class attribute, so the @abstractmethod below is
    # not enforced at instantiation -- confirm whether that is intended.
    __metaclass__ = ABCMeta

    # VERSION = 'F{}-0'.format(VmafFeatureExtractor.VERSION)
    VERSION = f'F{VmafIntegerFeatureExtractor.VERSION}int-0'

    @property
    @abstractmethod
    def FEATURE_NAME(self):
        # to be replaced by a plain class attribute in subclasses
        raise NotImplementedError

    def _get_quality_scores(self, asset):
        # not used: _run_on_asset() is overridden below
        raise NotImplementedError

    def _generate_result(self, asset):
        # not used: _run_on_asset() is overridden below
        raise NotImplementedError

    def _get_vmaf_feature_assembler_instance(self, asset):
        """
        Build a FeatureAssembler that extracts only FEATURE_NAME of the
        'VMAF_integer_feature' extractor for the given asset, reusing this
        runner's logger, result store and other settings.
        """
        return FeatureAssembler(
            feature_dict={'VMAF_integer_feature': [self.FEATURE_NAME]},
            feature_option_dict=None,
            assets=[asset],
            logger=self.logger,
            fifo_mode=self.fifo_mode,
            delete_workdir=self.delete_workdir,
            result_store=self.result_store,
            optional_dict=self.optional_dict,
            optional_dict2=self.optional_dict2,
            parallelize=False,  # parallelization already in a higher level
            save_workfiles=self.save_workfiles,
        )

    @override(Executor)
    def _run_on_asset(self, asset):
        """
        Run the feature assembler on the asset and return a Result whose only
        entry maps this runner's scores key to the extracted feature's scores.
        """
        fassembler = self._get_vmaf_feature_assembler_instance(asset)
        fassembler.run()
        feature_result = fassembler.results[0]

        own_scores_key = self.get_scores_key()
        feature_scores_key = VmafIntegerFeatureExtractor.get_scores_key(self.FEATURE_NAME)

        return Result(asset, self.executor_id,
                      {own_scores_key: feature_result[feature_scores_key]})

    @override(Executor)
    def _remove_result(self, asset):
        """
        Override Executor._remove_result(self, asset) by redirecting it to
        the FeatureAssembler.
        """
        self._get_vmaf_feature_assembler_instance(asset).remove_results()


class VifQualityRunner(VmafSingleFeatureQualityRunner):
    """Single-feature quality runner of type 'VIF'; its score is the 'vif' feature."""
    TYPE = 'VIF'
    FEATURE_NAME = 'vif'


class Vif2QualityRunner(VmafSingleFeatureQualityRunner):
    """Single-feature quality runner of type 'VIF2'; its score is the 'vif2' feature."""
    TYPE = 'VIF2'
    FEATURE_NAME = 'vif2'


class Adm2QualityRunner(VmafSingleFeatureQualityRunner):
    """Single-feature quality runner of type 'ADM2'; its score is the 'adm2' feature."""
    TYPE = 'ADM2'
    # TYPE = 'DLM'
    FEATURE_NAME = 'adm2'


class VifScale0QualityRunner(VmafSingleFeatureQualityRunner):
    """Single-feature quality runner of type 'VIF_SCALE0'; its score is the 'vif_scale0' feature."""
    TYPE = 'VIF_SCALE0'
    FEATURE_NAME = 'vif_scale0'


class VifScale1QualityRunner(VmafSingleFeatureQualityRunner):
    """Single-feature quality runner of type 'VIF_SCALE1'; its score is the 'vif_scale1' feature."""
    TYPE = 'VIF_SCALE1'
    FEATURE_NAME = 'vif_scale1'


class VifScale2QualityRunner(VmafSingleFeatureQualityRunner):
    """Single-feature quality runner of type 'VIF_SCALE2'; its score is the 'vif_scale2' feature."""
    TYPE = 'VIF_SCALE2'
    FEATURE_NAME = 'vif_scale2'


class VifScale3QualityRunner(VmafSingleFeatureQualityRunner):
    """Single-feature quality runner of type 'VIF_SCALE3'; its score is the 'vif_scale3' feature."""
    TYPE = 'VIF_SCALE3'
    FEATURE_NAME = 'vif_scale3'


class MotionQualityRunner(VmafSingleFeatureQualityRunner):
    """Single-feature quality runner of type 'MOTION'; its score is the 'motion' feature."""
    TYPE = 'MOTION'
    # TYPE = 'TI'
    FEATURE_NAME = 'motion'


class Motion2QualityRunner(VmafSingleFeatureQualityRunner):
    """Single-feature quality runner of type 'MOTION2'; its score is the 'motion2' feature."""
    TYPE = 'MOTION2'
    # TYPE = 'TI'
    FEATURE_NAME = 'motion2'


class BootstrapVmafQualityRunner(VmafQualityRunner):
    """
    VmafQualityRunner variant using BootstrapLibsvmNusvrTrainTestModel as its
    train-test model class. Besides the default score, its results carry the
    scores of all individual models, the bagging score, the stddev score and
    the ci95 low / high scores.
    """
    TYPE = "BOOTSTRAP_VMAF"
    VERSION = VmafQualityRunner.VERSION + '-' + 'M' + BootstrapLibsvmNusvrTrainTestModel.VERSION
    ALGO_VERSION = None

    # "vmaf_float_b_v0.6.3": plain bootstrapping, "vmaf_rb_v0.6.3": residue bootstrapping
    DEFAULT_MODEL_FILEPATH = VmafConfig.model_path("vmaf_b_v0.6.3.json")

    def _populate_result_dict(self, feature_result, pred_result):
        """
        Assemble the result dictionary: all feature results, the default
        score (via _populate_default_scores) and the bootstrap-specific
        scores taken from pred_result.
        """
        result_dict = dict(feature_result.result_dict)  # add feature result
        self._populate_default_scores(pred_result, result_dict)
        result_dict.update({
            # quality score from all models
            self.get_all_models_scores_key(): pred_result['ys_pred_all_models'],
            # bagging quality score
            self.get_bagging_scores_key(): pred_result['ys_pred_bagging'],
            # stddev of bootstrapped quality score
            self.get_stddev_scores_key(): pred_result['ys_pred_stddev'],
            # ci95 bounds of bootstrapped quality score
            self.get_ci95_low_scores_key(): pred_result['ys_pred_ci95_low'],
            self.get_ci95_high_scores_key(): pred_result['ys_pred_ci95_high'],
        })
        return result_dict

    def _populate_default_scores(self, pred_result, result_dict):
        # the default quality score is the plain 'ys_pred' prediction
        result_dict[self.get_scores_key()] = pred_result['ys_pred']

    @classmethod
    def predict_with_model(cls, model, xs, **kwargs):
        """
        Predict with a bootstrapped model and post-process the predictions.

        Score transform is applied when _do_transform_score(model, kwargs)
        says so; clipping is applied unless kwargs['disable_clip_score'] is
        True. Both are applied to the per-model, default, bagging and ci95
        predictions. The stddev is not mapped directly: it is multiplied by
        the slope of the applied mapping, measured numerically around the
        bagging prediction.

        Returns a dict with keys 'ys_pred_all_models', 'ys_pred',
        'ys_pred_bagging', 'ys_pred_stddev', 'ys_pred_ci95_low' and
        'ys_pred_ci95_high'.
        """
        DELTA = 1e-2

        raw = model.predict(xs)
        ys_pred_all_models = raw['ys_label_pred_all_models']
        ys_pred = raw['ys_label_pred']
        ys_pred_bagging = raw['ys_label_pred_bagging']
        ys_pred_stddev = raw['ys_label_pred_stddev']
        ys_pred_ci95_low = raw['ys_label_pred_ci95_low']
        ys_pred_ci95_high = raw['ys_label_pred_ci95_high']

        # bagging prediction nudged by +/-DELTA; sent through the same
        # mappings as the scores so that the slope can be measured afterwards
        ys_pred_plus = ys_pred_bagging + DELTA
        ys_pred_minus = ys_pred_bagging - DELTA

        # mappings to apply, in order: transform first, clip second
        mappings = []
        if cls._do_transform_score(model, kwargs):
            mappings.append(cls.transform_score)
        if kwargs.get('disable_clip_score') is not True:
            mappings.append(cls.clip_score)

        for mapping in mappings:
            ys_pred_all_models = np.array(
                [mapping(model, ys_pred_some_model) for ys_pred_some_model in ys_pred_all_models])
            ys_pred = mapping(model, ys_pred)
            ys_pred_bagging = mapping(model, ys_pred_bagging)
            ys_pred_plus = mapping(model, ys_pred_plus)
            ys_pred_minus = mapping(model, ys_pred_minus)
            ys_pred_ci95_low = mapping(model, ys_pred_ci95_low)
            ys_pred_ci95_high = mapping(model, ys_pred_ci95_high)

        # stddev score transform is applied after transform, clip, or both, or neither
        slope = (ys_pred_plus - ys_pred_minus) / (2.0 * DELTA)

        return {'ys_pred_all_models': ys_pred_all_models,
                'ys_pred': ys_pred,
                'ys_pred_bagging': ys_pred_bagging,
                'ys_pred_stddev': ys_pred_stddev * slope,
                'ys_pred_ci95_low': ys_pred_ci95_low,
                'ys_pred_ci95_high': ys_pred_ci95_high,
                }

    def get_train_test_model_class(self):
        # override VmafQualityRunner.get_train_test_model_class
        return BootstrapLibsvmNusvrTrainTestModel

    # Result keys. Each quantity has a '..._scores' key and a '..._score'
    # key, both prefixed with the runner's TYPE.

    @classmethod
    def get_all_models_scores_key(cls):
        return cls.TYPE + '_all_models_scores'

    @classmethod
    def get_all_models_score_key(cls):
        return cls.TYPE + '_all_models_score'

    @classmethod
    def get_bagging_scores_key(cls):
        return cls.TYPE + '_bagging_scores'

    @classmethod
    def get_bagging_score_key(cls):
        return cls.TYPE + '_bagging_score'

    @classmethod
    def get_stddev_scores_key(cls):
        return cls.TYPE + '_stddev_scores'

    @classmethod
    def get_stddev_score_key(cls):
        return cls.TYPE + '_stddev_score'

    @classmethod
    def get_ci95_low_scores_key(cls):
        return cls.TYPE + '_ci95_low_scores'

    @classmethod
    def get_ci95_low_score_key(cls):
        return cls.TYPE + '_ci95_low_score'

    @classmethod
    def get_ci95_high_scores_key(cls):
        return cls.TYPE + '_ci95_high_scores'

    @classmethod
    def get_ci95_high_score_key(cls):
        return cls.TYPE + '_ci95_high_score'


class BaggingVmafQualityRunner(BootstrapVmafQualityRunner):
    """
    BootstrapVmafQualityRunner variant that reports the bagging prediction as
    its main score.
    """
    TYPE = "BAGGING_VMAF"
    VERSION = VmafQualityRunner.VERSION + '-' + BootstrapLibsvmNusvrTrainTestModel.VERSION

    def _populate_result_dict(self, feature_result, pred_result):
        """
        Build the result dictionary as the parent class does, then replace
        the entry under get_scores_key() with the bagging score.
        """
        populated = super()._populate_result_dict(feature_result, pred_result)
        populated[self.get_scores_key()] = pred_result['ys_pred_bagging']
        return populated


class NiqeQualityRunner(QualityRunner):
    """
    Quality runner of type 'NIQE': features are extracted through a
    FeatureAssembler configured from the model file, and the score is the
    'ys_label_pred' output of the loaded model's predict().
    """
    TYPE = 'NIQE'

    # VERSION = '0.1'
    VERSION = '0.2'  # update model to py3

    DEFAULT_MODEL_FILEPATH = VmafConfig.model_path('other_models', 'niqe_v0.1.pkl')

    # used when the model file carries no 'feature_dict' info
    DEFAULT_FEATURE_DICT = {'NIQE_noref_feature': 'all'}

    def _get_quality_scores(self, asset):
        # not used: _run_on_asset() is overridden below
        raise NotImplementedError

    def _generate_result(self, asset):
        # not used: _run_on_asset() is overridden below
        raise NotImplementedError

    def _get_niqe_feature_assembler_instance(self, asset):
        """
        Build the FeatureAssembler for the asset. Its feature_dict and
        optional_dict come from the model's appended info ('feature_dict',
        falling back to DEFAULT_FEATURE_DICT, and 'feature_optional_dict').
        """
        # the TrainTestModel is loaded here only for its appended info
        model = self._load_model(asset)

        # need this so that FeatureAssembler can find NiqeNorefFeatureExtractor:
        from vmaf.core.noref_feature_extractor import NiqeNorefFeatureExtractor

        feature_dict = model.get_appended_info('feature_dict')
        feature_dict = self.DEFAULT_FEATURE_DICT if feature_dict is None else feature_dict

        feature_optional_dict = model.get_appended_info('feature_optional_dict')

        return FeatureAssembler(
            feature_dict=feature_dict,
            feature_option_dict=None,
            assets=[asset],
            logger=self.logger,
            fifo_mode=self.fifo_mode,
            delete_workdir=self.delete_workdir,
            result_store=self.result_store,
            optional_dict=feature_optional_dict,
            optional_dict2=None,
            parallelize=False,
            save_workfiles=self.save_workfiles,
        )

    def _load_model(self, asset):
        """
        Load the model from optional_dict['model_filepath'] when that entry
        exists and is not None, from DEFAULT_MODEL_FILEPATH otherwise.
        """
        options = self.optional_dict
        model_filepath = None if options is None else options.get('model_filepath')
        if model_filepath is None:
            model_filepath = self.DEFAULT_MODEL_FILEPATH
        return TrainTestModel.from_file(model_filepath, self.logger)

    @override(Executor)
    def _run_on_asset(self, asset):
        """
        Override Executor._run_on_asset(self, asset): run the
        FeatureAssembler, collect the features, call the model's predict()
        on them and return a Result holding all feature results plus the
        quality score. Executor._run_on_asset(self, asset) and
        QualityRunner._read_result(self, asset) are both bypassed.
        """
        fassembler = self._get_niqe_feature_assembler_instance(asset)
        fassembler.run()
        feature_result = fassembler.results[0]

        # xs = NiqeTrainTestModel.get_perframe_xs_from_result(feature_result)
        xs = NiqeTrainTestModel.get_xs_from_results([feature_result])

        ys_pred = self._load_model(asset).predict(xs)['ys_label_pred']

        # all feature results first, then the quality score on top
        result_dict = dict(feature_result.result_dict)
        result_dict[self.get_scores_key()] = ys_pred

        return Result(asset, self.executor_id, result_dict)

    @override(Executor)
    def _remove_result(self, asset):
        """
        Override Executor._remove_result(self, asset) by redirecting it to
        the FeatureAssembler.
        """
        self._get_niqe_feature_assembler_instance(asset).remove_results()


class VmafexecQualityRunner(QualityRunner, FeatureDiscoveryMixin):
    """
    Quality runner that produces its scores by calling
    ExternalProgramCaller.call_vmafexec() and then parsing the per-frame
    entries of the XML log file that call writes.

    Options understood in optional_dict (all optional):
        models: list of model specification strings, each containing
            'name=<name>'; when given (and not None) it takes precedence over
            every other model-related option.
        use_default_built_in_model: bool; when True no model is passed to
            the executable (empty model list).
        model_filepath, disable_clip_score: used to build the single default
            model specification 'name=vmaf:path=...[:disable_clip]'.
        float_psnr, float_ssim, float_ms_ssim, float_moment, psnr, ssim,
            ms_ssim, no_prediction, disable_avx, motion_force_zero: bool,
            default False.
        subsample, n_threads: int >= 1, default 1.
        disable_enhn_gain: bool; mutually exclusive with
            vif_enhn_gain_limit / adm_enhn_gain_limit (int or float).
    """

    TYPE = 'VMAFEXEC'

    VERSION = 'F' + VmafFeatureExtractor.VERSION + '-0.6.1'
    ALGO_VERSION = 2

    DEFAULT_MODEL_FILEPATH = vmaf.model_path("vmaf_v0.6.1.json")

    # features looked up in every frame of the log; see _get_quality_scores()
    FEATURES = [
                'adm2', 'motion2', 'vif_scale0', 'vif_scale1', 'vif_scale2', 'vif_scale3',
                'adm_scale0', 'adm_scale1','adm_scale2','adm_scale3', 'motion',
                'float_psnr', 'psnr_y', 'psnr_cb', 'psnr_cr',
                'float_ssim', 'float_ms_ssim', 'ssim', 'ms_ssim',
                'float_moment_ref1st', 'float_moment_dis1st', 'float_moment_ref2nd', 'float_moment_dis2nd',
                ]

    @classmethod
    def get_feature_scores_key(cls, atom_feature):
        """Return the result key '<TYPE>_<atom_feature>_scores'."""
        return "{type}_{atom_feature}_scores".format(
            type=cls.TYPE, atom_feature=atom_feature)

    def _generate_result(self, asset):
        """
        Call the command-line executable on the asset so that it writes the
        quality scores into the asset's log file.

        Raises AssertionError when an option has the wrong type or when
        disable_enhn_gain is combined with an explicit enhancement gain limit.
        """
        options = self.optional_dict if self.optional_dict is not None else {}

        def bool_option(key):
            # boolean switch defaulting to False
            value = options.get(key, False)
            assert isinstance(value, bool)
            return value

        def positive_int_option(key):
            # integer option defaulting to 1
            value = options.get(key, 1)
            assert isinstance(value, int) and value >= 1
            return value

        log_file_path = self._get_log_file_path(asset)

        # ==== resolve the list of model specifications ====
        if options.get('models') is not None:
            assert isinstance(options['models'], list)
            models = options['models']
        elif bool_option('use_default_built_in_model'):
            models = []
        else:
            # Also reached when use_default_built_in_model is explicitly
            # False: that case previously left `models` unassigned.
            model0 = ['name=vmaf']

            model_filepath = options.get('model_filepath', self.DEFAULT_MODEL_FILEPATH)
            assert isinstance(model_filepath, str)
            model0.append(f'path={model_filepath}')

            if bool_option('disable_clip_score'):
                model0.append('disable_clip')

            models = [':'.join(model0)]

        float_psnr = bool_option('float_psnr')
        float_ssim = bool_option('float_ssim')
        float_ms_ssim = bool_option('float_ms_ssim')
        # validated on its own value (the check used to look at float_ssim)
        float_moment = bool_option('float_moment')
        psnr = bool_option('psnr')
        ssim = bool_option('ssim')
        ms_ssim = bool_option('ms_ssim')
        no_prediction = bool_option('no_prediction')
        subsample = positive_int_option('subsample')
        n_threads = positive_int_option('n_threads')
        disable_avx = bool_option('disable_avx')

        disable_enhn_gain = options.get('disable_enhn_gain')
        assert disable_enhn_gain is None or isinstance(disable_enhn_gain, bool)

        vif_enhn_gain_limit = options.get('vif_enhn_gain_limit')
        assert vif_enhn_gain_limit is None or isinstance(vif_enhn_gain_limit, (int, float))

        adm_enhn_gain_limit = options.get('adm_enhn_gain_limit')
        assert adm_enhn_gain_limit is None or isinstance(adm_enhn_gain_limit, (int, float))

        # disable_enhn_gain and the explicit limits are mutually exclusive
        assert disable_enhn_gain is None or \
            (vif_enhn_gain_limit is None and adm_enhn_gain_limit is None)

        motion_force_zero = bool_option('motion_force_zero')

        # ==== translate disable_enhn_gain into vif_enhn_gain_limit and adm_enhn_gain_limit: ====
        if disable_enhn_gain is True:
            vif_enhn_gain_limit = 1.0
            adm_enhn_gain_limit = 1.0

        width, height = asset.quality_width_height
        fmt = self._get_workfile_yuv_type(asset)
        reference = asset.ref_procfile_path
        distorted = asset.dis_procfile_path
        pixel_format, bitdepth = convert_pixel_format_ffmpeg2vmafexec(fmt)
        exe = self._get_exec()

        ExternalProgramCaller.call_vmafexec(
            reference, distorted, width, height, pixel_format, bitdepth,
            float_psnr, psnr, float_ssim, ssim, float_ms_ssim, ms_ssim, float_moment,
            no_prediction, models, subsample, n_threads, disable_avx, log_file_path, exe, self.logger,
            vif_enhn_gain_limit, adm_enhn_gain_limit, motion_force_zero)

    def _get_exec(self):
        return None  # signaling default

    def _get_quality_scores(self, asset):
        """
        Read the quality scores back from the asset's log file and return
        them as a dictionary.

        Score keys: none when no_prediction is set; the names found in the
        'models' option when that option is given and not None; the single
        key 'vmaf' otherwise. The 'vmaf' scores are stored under
        get_scores_key(), any other model's scores under
        get_feature_scores_key(<model name>). Every feature of FEATURES found
        in the log is stored under get_feature_scores_key(<nickname>).
        """
        options = self.optional_dict if self.optional_dict is not None else {}

        log_file_path = self._get_log_file_path(asset)
        frames = ElementTree.parse(log_file_path).getroot().findall('frames/frame')

        feature_scores = [[] for _ in self.FEATURES]
        feature_nicknames = [None for _ in self.FEATURES]

        no_prediction = options.get('no_prediction', False)
        assert isinstance(no_prediction, bool)

        if no_prediction:
            scores_keys = []
        elif options.get('models') is not None:
            # models=None is treated like an absent option, in agreement with
            # _generate_result(), which falls back to the default model then
            assert isinstance(options['models'], list)
            scores_keys = [model.split("name=")[1].split(":")[0]
                           for model in options['models']]
        else:
            scores_keys = ['vmaf']

        scores_dict = {scores_key: [] for scores_key in scores_keys}

        for frame in frames:
            # scores_keys is empty under no_prediction, so nothing is read then
            for scores_key in scores_keys:
                scores_dict[scores_key].append(float(frame.attrib[scores_key]))

            for i_feature, feature in enumerate(self.FEATURES):
                # tried in this order, stopping at the first hit:
                #   exact integer_xxx, exact xxx,
                #   wildcard integer_xxx_*, wildcard xxx_*
                lookups = (
                    (self._discover_feature_exact, 'integer_' + feature),
                    (self._discover_feature_exact, feature),
                    (self._discover_feature_wildcard, 'integer_' + feature + '_'),
                    (self._discover_feature_wildcard, feature + '_'),
                )
                for discover, name_or_prefix in lookups:
                    if discover(frame, i_feature, name_or_prefix, feature,
                                feature_scores, feature_nicknames):
                        break

        # each requested score must be present, unless at least some feature is
        any_feature_found = any(len(feature_score) != 0 for feature_score in feature_scores)
        for scores_key in scores_keys:
            assert len(scores_dict[scores_key]) != 0 or any_feature_found

        quality_result = {}
        for scores_key in scores_keys:
            if scores_key != 'vmaf':
                quality_result[self.get_feature_scores_key(scores_key)] = scores_dict[scores_key]
            else:
                quality_result[self.get_scores_key()] = scores_dict[scores_key]

        for nickname, values in zip(feature_nicknames, feature_scores):
            if len(values) != 0:
                assert nickname is not None
                quality_result[self.get_feature_scores_key(nickname)] = values
        return quality_result


if __name__ == '__main__':
    # when executed as a script, run the doctests found in this module's
    # docstrings
    import doctest
    doctest.testmod()
